import copy
from time import time
import torch
import ntk_utils
from models.resnet import resnet18
import dataloader
import torchvision
from torchvision import transforms
import utils
from loss_functions import fgsm, pgd
from functorch.experimental import replace_all_batch_norm_modules_
import argparse
# Command-line options for the NTK extraction script.
parser = argparse.ArgumentParser()
# --arch / --dataset select the checkpoint directory ./<dataset>/<arch>/ .
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet',
					help='model architecture')
parser.add_argument('--dataset', default='cifar10', type=str,
					help='which dataset used to train')

# Experiment name: used both as the checkpoint sub-directory and as the
# checkpoint file prefix (<exp>/<exp>.pkl<epoch>.pkl).
parser.add_argument('--exp', default='fgsm', type=str,
					help='exp name')

# 0 -> compute kernels on the training subset; any other value -> test subset.
parser.add_argument('--test', default=0, type=int,
					help='if on test set')

## BN
# 0 -> BatchNorm modules are replaced so they do not record running stats;
# any other value leaves BatchNorm untouched.
parser.add_argument('--if_bn', default=0, type=int,
					help='0: replace BatchNorm modules so running mean/std are not recorded; '
						 'nonzero: keep BatchNorm as is')

args = parser.parse_args()
method = args.exp

# Checkpoint path template, later filled per epoch as path % (exp, exp, epoch).
path = './%s/%s/' % (args.dataset, args.arch) + '%s/%s.pkl%d.pkl'

# Epochs 1..9 individually, then every 10th epoch from 10 to 200.
epoch_list = list(range(1, 10)) + list(range(10, 201, 10))
print(epoch_list)

# Subset size handed to utils.split_dataset for each split
# (presumably the number of samples kept — confirm in utils).
total_data_num = 500

transform_test = transforms.Compose([
	torchvision.transforms.Resize((32, 32)),
	transforms.ToTensor(),
])

data = dataloader.Data('cifar10', './data')
trainset, testset = data.data_loader(transform_test, transform_test)

bs = 20


def _eval_loader(dataset):
	"""Wrap `dataset` in a deterministic (unshuffled, keep-last) DataLoader of batch size `bs`."""
	return torch.utils.data.DataLoader(
		dataset,
		batch_size=bs,
		shuffle=False,
		drop_last=False,
		num_workers=0,
	)


# Keep the original order (test split first) in case split_dataset is random.
new_testset = utils.split_dataset(testset, total_data_num)
test_loader = _eval_loader(new_testset)

new_trainset = utils.split_dataset(trainset, total_data_num)
train_loader = _eval_loader(new_trainset)

# ResNet-18 with a 10-class head; handed to ntk_utils.load_model for every epoch.
n = resnet18(num_classes=10)

if not args.if_bn:
	# functorch helper: make BatchNorm layers stop recording running statistics.
	replace_all_batch_norm_modules_(n)
	print('BN is disable to record running mean and std')

def _collect_jacobians(model, loader, attack=None):
	"""Run `model` over every batch of `loader` and collect empirical-NTK Jacobians.

	Each batch is unpacked as ``(x, y, idx)``; `idx` is unused.  When `attack`
	is given it is called as ``attack(model, x, y)``; its output replaces `x`
	and the model's argmax predictions on those inputs are recorded.

	Returns:
		jac_all: one tensor per entry returned by
			``ntk_utils.empirical_ntk_jacobian_contraction_symmetric``, with the
			per-batch pieces concatenated along dim 0 (presumably the sample
			dimension — confirm in ntk_utils).
		labels: true labels of all batches, concatenated, on CPU.
		preds: predictions on the attacked inputs (CPU), or None if no attack.

	Raises:
		RuntimeError: if `loader` yields no batches.
	"""
	s = time()
	jac_all = []
	labels = []
	preds = []
	for x, y, _idx in loader:
		# Re-assert eval mode each batch: it is not visible here whether the
		# attack changes the model's train/eval mode.
		model.eval()
		x, y = x.cuda(), y.cuda()
		if attack is not None:
			x = attack(model, x, y)
			# Predictions only — no autograd graph needed.
			with torch.no_grad():
				preds.append(model(x).argmax(dim=1))
		labels.append(y)

		net, params = ntk_utils.warp_model(model)
		jac_sub = ntk_utils.empirical_ntk_jacobian_contraction_symmetric(net, params, x)
		if not jac_all:
			jac_all = [[jac] for jac in jac_sub]
		else:
			for chunks, jac in zip(jac_all, jac_sub):
				chunks.append(jac)
		del net, params
	print('Time Cost:', time() - s)
	if not jac_all:
		raise RuntimeError('loader yielded no batches; cannot build an NTK')

	s = time()
	print('JAC Matrix Concating')
	jac_all = [torch.concat(chunks, dim=0) for chunks in jac_all]
	print('Time Cost:', time() - s)

	labels = torch.concat(labels, dim=0).cpu()
	preds = torch.concat(preds, dim=0).cpu() if attack is not None else None
	return jac_all, labels, preds


def _ntk_from_jacobians(jac_all):
	"""Sum ``einsum('Naf,Mbf->NMab', j, j)`` over every Jacobian `j` in `jac_all`.

	The result is indexed ``[N, M, a, b]``.  Terms are accumulated in place
	rather than stacked first, so only the running sum and the current term
	are alive at once instead of one kernel-sized tensor per entry.
	"""
	s = time()
	print('Kernel Matrix Computing')
	kernel = None
	for jac in jac_all:
		term = torch.einsum('Naf,Mbf->NMab', jac, jac)
		if kernel is None:
			kernel = term
		else:
			kernel += term
	print('Time Cost:', time() - s)
	print(kernel.size())
	return kernel


# Outputs go next to the checkpoints they were computed from
# (./cifar10/resnet/<exp>/ with the default arguments).
out_dir = './%s/%s/%s/' % (args.dataset, args.arch, method)
# Files computed on the test subset carry a '_test' suffix.
suffix = '' if args.test == 0 else '_test'
# Only the selected split is iterated; the other loader is never touched.
loader = train_loader if args.test == 0 else test_loader

for epoch in epoch_list:
	path_epoch = path % (method, method, epoch)
	print(path_epoch)
	model = ntk_utils.load_model(n, path_epoch).cuda()

	print('calculate NTK for clean data')
	jac_all, y_list_clean, _ = _collect_jacobians(model, loader)
	matrix_ae_clean = _ntk_from_jacobians(jac_all)
	del jac_all  # release the clean Jacobians before the PGD pass allocates its own
	torch.save(y_list_clean, out_dir + 'label_ae_clean%d_clean%s.pt' % (epoch, suffix))

	print('calculate NTK for pgd data')
	# Predictions on the adversarial inputs come from the loaded `model`.
	jac_all, y_list_clean, y_list_ae = _collect_jacobians(model, loader, attack=pgd)
	matrix_ae_pgd = _ntk_from_jacobians(jac_all)
	del jac_all
	torch.save(y_list_clean, out_dir + 'label_ae_pgd%d_clean%s.pt' % (epoch, suffix))
	torch.save(y_list_ae, out_dir + 'label_ae_pgd%d_ae%s.pt' % (epoch, suffix))

	torch.save(matrix_ae_clean, out_dir + 'matrix_ae_clean%d%s.pt' % (epoch, suffix))
	torch.save(matrix_ae_pgd, out_dir + 'matrix_ae_pgd%d%s.pt' % (epoch, suffix))



